iT邦幫忙

2022 iThome 鐵人賽

DAY 6
0
AI & Data

JAX 好好玩系列 第 6

JAX 好好玩 (6) : JAX.NUMPY (2) : 虛擬亂數產生器

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

雖然 jax.numpy 是依據 Numpy 的語法和語義來設計的,但仍有幾個不同點需要注意。第一個要介紹的是「虛擬亂數產生器 (Pseudo Random Number Generator; PRNG)」。

Numpy 的 PRNG

在 Numpy 中要產生亂數,所使用的演算法叫 MT19937,它是「梅森旋轉法 Mersenne Twister」的一個變體 [6.1]。 亂數是藉由一個「全域狀態 (global state) 」來產生的,每一個亂數產生的操作,都會導致全域狀態的改變。

在這裏大家要注意「全域」這兩個字,意指一個 Python 進程 (process) 中的所有執行緒 (thread) 皆共用同一個全域狀態,不同執行緒之間的亂數產生行為彼此相互影響,這個是不個不太好的設計。

Numpy 提供 np.random.seed() 來設定全域狀態的初始值,例如:

np.random.seed(0)

全域狀態是由 624 個 32-bit unsigned int (產生亂數用) 和數個屬性參數 [6.2] 所構成 (在此不細談),Numpy 提供 np.random.get_state() 來獲取全域狀態的內容:

prng_state = np.random.get_state()
index = 0
for element in prng_state:
    print(f'Tuple {index} type: {type(element)} - \
    {element.shape if isinstance(element,np.ndarray) else ""}')
    index += 1

output:
Tuple 0 type: <class 'str'>
Tuple 1 type: <class 'numpy.ndarray'> - (624,)
Tuple 2 type: <class 'int'>
Tuple 3 type: <class 'int'>
Tuple 4 type: <class 'float'>

全域狀態唯一決定了下一個亂數所產生的值,從以下的程式片斷可以看出來,只要重設全域狀態至相同的初始值,其後產生的亂數都是相同的。

np.random.seed(7)
print(np.random.uniform())
print(np.random.uniform())

print("==============================")
np.random.seed(7)
print(np.random.uniform())
print(np.random.uniform())

output:
0.07630828937395717
0.7799187922401146
'=============================='
0.07630828937395717
0.7799187922401146

MT19937 有以下的缺點,致使 JAX 在設計時,決定捨棄它,採用更新的方法。

  • 因為採用全域狀態,因此在多執行緒的程式中,比較難複製跟亂數相關的程式異常狀況。
  • 需要 2.5Kb 的空間在存放全域狀態。
  • (相較於 jax.numpy 的新方法) 它的執行速度比較慢。
  • 最重要的,它無法通過新的 Big Crush [6.3] 測試。

JAX 的 PRNG

我們先用一個例子來說明 JAX PRNG 的使用方法。

from jax import random

# get the key
key = random.PRNGKey(0)
# split the key
key, subkey = random.split(key)
# use subkey
print(random.normal(subkey, shape=(2,)))

# splict the key again
key, subkey = random.split(key)
# usb subkey 
print(random.normal(subkey, shape=(2,)))

output:
[ 0.19307722 -0.52678293]
[ 0.00870701 -0.04888523]

首先要注意的是,JAX 有關亂數產生的 API 是放在 jax.random 之下的,而非 jax.numpy 下:

from jax import random

使用前,要先產生一個 key ,JAX PRNG 是利用 key 來產生亂數,每一個亂數生成 API 都需要輸入 key 值。

key = random.PRNGKey(0)

key 不要直接用,要先用 jax.randdom.split() API 分割成兩個 key [6.4],其中一個 (key) 保留起來,另外一個 (subkey) 可以用來產生亂數。

key, subkey = random.split(key)
print(random.normal(subkey, shape=(2,)))

要再一次產生亂數前,要先分割上次保留的 key,保留一個,使用一個,如此生生不息。

key, subkey = random.split(key)
print(random.normal(subkey, shape=(2,)))

以下的流程圖,詳細說明了 JAX PRNG 的使用方法:
https://ithelp.ithome.com.tw/upload/images/20220916/20129616Ec4lAb2crr.png

我們可以總結 JAX PRNG 不同於 Numpy 的特性:

  • JAX 沒有「全域狀態」,而是利用 key 來產生亂數。
  • 自己的 key 自己生,程式內不同的功能模塊乃至執行時不同的執行緒,都應維護自己的 key 。
  • 自己的 key 自己變,JAX 不會幫你改變 key,你必須遵循 JAX 的 key 分割 (split) 原則來處理你的 key。
  • 同樣的 key 會產生同樣的亂數,所以千萬不要重複使用相同的 key ,否則你的亂數就不是真的「亂」。

註:

[6.1] 參考 維基百科
[6.2] 可以參考 numpy 的使用者手冊
[6.3] 有關 Big Crush 可參考 http://www.iro.umontreal.ca/~lecuyer/myftp/papers/testu01.pdf 。另外可以參考 這份報告 ,在 R 的 MT19937 實作上,會有 2 個 Big Crush 項目失敗。
[6.4] 分割成兩個 key 是一個過份簡化的說法,其實 split() 可以用參數指定分割後的數量,在後續的貼文中老頭會加以說明。


上一篇
JAX 好好玩 (5) : JAX.NUMPY (1) : 一個更好的 Numpy
下一篇
JAX 好好玩 (7) : JAX.NUMPY (3) : 再探 JAX PRNG
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言